#!/usr/bin/python3
#
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2021-2023, ByteDance Ltd. and/or its Affiliates
# Author: Yuanhan Liu <liuyuanhan.131@bytedance.com>
# Author: Tao Liu <liutao.xyz@bytedance.com>
#
# Tao Liu: rewrote in python

import sys
import time
import json
import argparse
import subprocess
import signal
import shutil
from operator import itemgetter

# Runtime options.  These are the defaults; parse_args() overwrites them
# with the values given on the command line.
interval = 1	# seconds to sleep between two samples
sid_list = ''	# comma separated sids handed to tpa; '' adds no sid argument
active = False	# when True, sockets whose counters did not change are hidden
sort = 0	# 1-based index into rated_stats_base to sort by; 0 disables sorting

# Counters whose per-interval delta is reported, in column order.
rated_stats_base = [
	'PKT_RECV',
	'BYTE_RECV',
	'PKT_XMIT',
	'BYTE_XMIT',
	'PKT_RE_XMIT',
	'BYTE_RE_XMIT',
]

# Per-counter display settings:
#   'unit'  - divisor applied to the counter delta before it is printed
#   'width' - %-format used for the counter's column
stats_cfg = {
	'PKT_RECV':     {'unit': 1 << 20, 'width': '%-10s'},
	'BYTE_RECV':    {'unit': 1 << 20, 'width': '%-10s'},
	'PKT_XMIT':     {'unit': 1 << 20, 'width': '%-10s'},
	'BYTE_XMIT':    {'unit': 1 << 20, 'width': '%-10s'},
	'PKT_RE_XMIT':  {'unit': 1 << 10, 'width': '%-13s'},
	'BYTE_RE_XMIT': {'unit': 1 << 10, 'width': '%-13s'},
}


def parse_args(argv=None):
	"""Parse the command line and publish the options as module globals.

	argv: optional list of argument strings; None (the default) lets
	      argparse read sys.argv[1:].

	Sets the globals interval, sid_list, active and sort, and returns the
	parsed argparse namespace.  Invalid input, including a non-positive
	interval, ends the program through argparse's error path (status 2).
	"""
	def positive_int(text):
		# time.sleep() raises on a negative interval, and a zero interval
		# turns the sampling loop into a busy loop with a near-zero rate
		# divisor, so both are refused up front.
		value = int(text)
		if value <= 0:
			raise argparse.ArgumentTypeError("must be a positive integer: %s" % text)
		return value

	parser = argparse.ArgumentParser()
	parser.add_argument("-l", "--list", type=str, default='', help="sid list (like: 122,112)")
	parser.add_argument("-i", "--interval", type=positive_int, default=1,  help="interval (default to 1s)")
	parser.add_argument("-a", "--active", action="store_true", default=False, help="show active tsock")
	# The default 0 lies outside choices on purpose: argparse does not
	# validate defaults, and 0 means "do not sort".
	parser.add_argument("-s", "--sort", type=int, default=0, choices=range(1, 7),
			    help="1: rx.mpps, 2: rx.MB/s, 3: tx.mpps, 4: tx.MB/s, 5: retrans.kpps, 6: retrans.KB/s")
	args = parser.parse_args(argv)

	global interval, sid_list, active, sort
	interval, sid_list, active, sort = (args.interval, args.list, args.active, args.sort)
	return args

def clear_screan():
	"""Erase the terminal and move the cursor to row 1, column 1."""
	erase_display = "\x1b[2J"
	cursor_home = "\x1b[1;1H"
	# No trailing newline, so the next print starts on the first row.
	print(erase_display + cursor_home, end = '')

def uniform(n, unit, time_detla):
	"""Convert a counter delta into a per-second rate expressed in `unit`.

	n:          counter delta
	unit:       divisor that scales the raw value (e.g. 1<<20)
	time_detla: elapsed time the delta was accumulated over

	Returns the rate rounded to three decimal places.  A zero result is
	returned as the integer 0, so str() of it is "0" rather than "0.0".
	"""
	# Scale by 1000 before rounding and back afterwards to keep 3 decimals.
	scaled = float(n) / unit * 1000 / float(time_detla)
	rate = round(scaled) / 1000
	if rate == 0.0:
		return 0
	return rate

def cmd_exec(args):
	"""Run an external command and return its standard output as text.

	args: the command, in any form subprocess.check_output() accepts.

	Returns the command's stdout decoded as UTF-8.  If the command exits
	with a non-zero status the program is terminated through sys.exit()
	with a message naming the command and its return code.
	"""
	try:
		output = subprocess.check_output(args)
	except subprocess.CalledProcessError as e:
		# Expand the placeholders with '%'; concatenating the return code
		# onto the format string would leave a literal "%s" in the message.
		sys.exit('error exec cmd: %s (return code %d)' % (e.cmd, e.returncode))
	return output.decode('utf-8')

def sock_stats_parse():
	"""Fetch the current socket list from `tpa sock-list -v -j`.

	Returns (tsocks, tsock_map): tsocks is the decoded JSON output and
	tsock_map maps each entry's 'sid' to its index in tsocks.

	Empty output, an empty JSON document, or any exception raised while
	running the command or decoding its output ends the program through
	sys.exit() with a "catch error: ..." message.
	"""
	cmd = ['tpa', 'sock-list', '-v', '-j']
	if sid_list != "":
		# NOTE(review): the sids are handed over as ONE space separated
		# argument, not as separate argv entries -- confirm this is what
		# tpa expects.
		cmd.append(sid_list.replace(',', ' '))

	try:
		stdout = cmd_exec(cmd)
		if not stdout:
			raise Exception("zero output cmd: %s" % cmd)
		tsocks = json.loads(stdout)
		if len(tsocks) == 0:
			raise Exception("failed to load json output")
	except Exception as e:
		sys.exit("catch error: {}".format(e))

	# Index by sid so a socket can be looked up in this sample later on.
	tsock_map = {tsock['sid']: pos for pos, tsock in enumerate(tsocks)}

	return tsocks, tsock_map

def find_last_tsock(last, tsock_map, sid):
	"""Return the entry of `last` that `tsock_map` records for `sid`.

	last:      sequence of socket entries
	tsock_map: mapping from sid to an index into `last`
	sid:       socket id to look up

	Returns an empty dict when `sid` is not a key of `tsock_map`.
	"""
	try:
		pos = tsock_map[sid]
	except KeyError:
		return {}
	return last[pos]

def calc_rated_stats(delta, tsock, last_tsock):
	"""Fill `delta` with the change of every counter in rated_stats_base.

	delta:      dict that receives one entry per counter name
	tsock:      current sample of a socket
	last_tsock: previous sample of the same socket

	A counter missing from `tsock` yields 0; a counter missing from
	`last_tsock` yields the current value unchanged.

	Returns True when every delta is 0 (the socket was idle), else False.
	"""
	changed = False

	for name in rated_stats_base:
		if name in tsock:
			value = tsock[name]
			if name in last_tsock:
				value = value - last_tsock[name]
		else:
			value = 0
		delta[name] = value
		changed = changed or value != 0

	return not changed

def sock_stats_diff(curr, last, tsock_map):
	"""Build the per-socket counter deltas between two samples.

	curr:      current list of socket entries
	last:      previous list of socket entries
	tsock_map: sid -> index mapping for `last`

	Sockets without a (non-empty) previous entry are left out, and so are
	idle sockets when the global `active` is set.  Each result row holds
	the counter deltas plus the socket's 'sid', 'state' and 'connection'.
	When the global `sort` is positive the rows are ordered descending by
	the counter rated_stats_base[sort - 1].
	"""
	rows = []

	for tsock in curr:
		prev = find_last_tsock(last, tsock_map, tsock["sid"])
		if not prev:
			continue

		row = {}
		# The deltas are always computed; only then is idleness checked.
		if calc_rated_stats(row, tsock, prev) and active:
			continue

		row.update(sid = tsock['sid'],
			   state = tsock['state'],
			   connection = tsock['connection'])
		rows.append(row)

	if sort <= 0:
		return rows

	sort_key = itemgetter(rated_stats_base[sort - 1])
	return sorted(rows, key = sort_key, reverse = True)

def calc_total_and_format(delta):
	"""Append a 'total' row to `delta` and return the row format string.

	delta: list of per-socket rows; it is modified in place.

	The appended row sums every counter of rated_stats_base over the
	existing rows, uses 'total' as sid, the number of rows as state and
	'-' as connection.  The returned %-format has one field each for sid
	and state, one per counter (widths from stats_cfg) and one for the
	connection.
	"""
	total = {'sid': 'total', 'state': str(len(delta)), 'connection': '-'}
	for name in rated_stats_base:
		total[name] = sum(row[name] for row in delta)
	# Appended only after summing, so the total does not count itself.
	delta.append(total)

	widths = [stats_cfg[name]['width'] for name in rated_stats_base]
	return '%-7s%-13s' + ''.join(widths) + '%s'

def sock_stats_print(delta, time_delta):
	"""Redraw the screen with one line per socket followed by the total.

	delta:      list of per-socket rows; calc_total_and_format() appends
	            the total row to it
	time_delta: elapsed time handed to uniform() to turn deltas into rates

	The total row (the last one) is always printed and carries no trailing
	newline.  Other rows are dropped once their index reaches the terminal
	height minus two.
	"""
	row_fmt = calc_total_and_format(delta)
	term_lines = shutil.get_terminal_size().lines

	clear_screan()
	header = ('sid', 'state', 'rx.mpps', 'rx.MB/s', 'tx.mpps',
		  'tx.MB/s', 'retrans.kpps', 'retrans.KB/s', 'connection')
	print(row_fmt % header)

	total_idx = len(delta) - 1
	for idx, row in enumerate(delta):
		is_total = idx == total_idx
		# Skip rows that do not fit, but never the total.
		if not is_total and term_lines > 0 and idx >= term_lines - 2:
			continue

		cells = [str(row['sid']), row['state']]
		for name in rated_stats_base:
			rate = uniform(row[name], stats_cfg[name]['unit'], time_delta)
			cells.append(str(rate))
		cells.append(row['connection'])

		print(row_fmt % tuple(cells), end = '' if is_total else '\n')

	sys.stdout.flush()

def sig_handler(s, f):
	"""Signal handler: print a newline and exit with status 0.

	s, f: the signal number and stack frame passed to the handler;
	      both are ignored.
	"""
	print()
	raise SystemExit(0)

def sock_stats_show():
	"""Main loop: sample the socket stats forever and print the rates.

	Parses the command line, installs sig_handler for SIGINT, then
	repeatedly fetches a sample, prints the difference to the previous
	sample (nothing is printed for the very first one) and sleeps for
	`interval` seconds.  Never returns normally.
	"""
	parse_args()
	signal.signal(signal.SIGINT, sig_handler)

	prev_socks = []
	prev_index = {}
	prev_stamp = time.time()
	elapsed = interval
	while True:
		socks, index = sock_stats_parse()
		if not socks:
			sys.exit("failed to parse sock stats")

		# Nothing to diff against on the first pass.
		if prev_socks:
			rows = sock_stats_diff(socks, prev_socks, prev_index)
			sock_stats_print(rows, elapsed)

		prev_socks = socks
		prev_index = index
		time.sleep(interval)

		# Use the measured time, not the nominal interval, as the
		# divisor for the next round of rates.
		now = time.time()
		elapsed = now - prev_stamp
		prev_stamp = now

# Script entry point: start the sampling loop only when run directly.
if __name__ == "__main__":
	sock_stats_show()
